# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# generate kernel instances to speed up compilation
import copy
from dataclasses import dataclass, field
import fnmatch
import itertools
from pathlib import Path
from typing import List, Optional, Tuple

from codegen.cmake_config import GEN_DIR
from codegen.cpp_symbol_map import (
    MODE_MAP,
    LAYOUT_MAP,
    BIAS_CHECK_MAP,
    get_mask_check_map,
    get_mask_map,
    BIAS_MAP,
    FWD_DTYPE_MAP,
    BOOL_MAP,
    PIPELINE_ENUM_MAP,
    QSCALE_CHECK_MAP,
    QSCALE_MAP,
)
from codegen.utils import update_file


# Bit width of each supported element type; used to derive vector access sizes
# in FmhaFwdApiTrait.dcheck/dvcheck (vec = 128 bits / DTYPE_BITS[dtype]).
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8}

# bk0max -> divisor used in the hdim divisibility guards for the "qr" pipeline
# (see FmhaFwdApiTrait.dcheck/dvcheck).
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}

# pipeline tag -> fully qualified C++ pipeline class used for batch prefill.
FMHA_BATCH_PREFILL_PIPELINE_MAP = {
    "qr_async": "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync",
}

# Common header prepended to every generated .cpp file: license banner,
# provenance note and the includes all generated translation units need.
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
// auto generated by generate.py
#include "ck_tile/ops/fmha/block/variants.hpp"
#include "fmha_fwd.hpp"
"""

# Per-kernel instantiation template.  For one configuration (suffix {F_idx}) it
# declares the tile shape, traits, variant, mask, pipeline problem, epilogue and
# kernel types, then defines the explicit specialization of fmha_batch_prefill_<>
# that launches the kernel.  Every {F_*} placeholder is substituted in
# FmhaFwdKernel.template via str.format — doubled braces are literal C++ braces.
FMHA_FWD_KERNEL_BODY = """
using fmha_dtype_{F_idx} = {F_dtype};

using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;

using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
                                      ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
                                      ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
                                      ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
                                      ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
                                      {F_vlayout}>;

using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
                                                    {F_skpad},
                                                    {F_dpad},
                                                    {F_dvpad},
                                                    {F_logits},
                                                    {F_bias},
                                                    false,
                                                    {F_lse},
                                                    {F_dropout},
                                                    {F_qscale},
                                                    {F_occupancy}>;

using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;

using fmha_mask_{F_idx} = {F_mask};

using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
    fmha_shape_{F_idx},
    {F_mode},
    fmha_variant_{F_idx},
    fmha_mask_{F_idx},
    false,
    fmha_trait_{F_idx}>;

using fmha_pipeline_{F_idx} = {F_pipeline}<
    fmha_pipeline_problem_{F_idx}>;

using fmha_epilogue_{F_idx} =
    ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
                                           typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
                                           {F_spad}, {F_dvpad}>>;

using fmha_kernel_{F_idx} =
    ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;

using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
                        {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;

#include <iostream>

template<>
float fmha_batch_prefill_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_batch_prefill_args a)
{{
    using k_ = fmha_kernel_{F_idx};
    if(s.log_level_ > 0)
        std::cout << ", " << k_::GetName() << std::flush;
    auto [kargs, grids] = fmha_batch_prefill_create_kargs_and_grids<k_>(a);
    const dim3 blocks                      = k_::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
    return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""

# File name of the generated runtime-dispatch translation unit.
FMHA_FWD_API_FILENAME = "fmha_batch_prefill_api.cpp"
# Top-level dispatcher template: queries the device CU count (for the optional
# CU-utilization constraints), then runs the nested if/else ladder emitted into
# {F_dispatch} by FmhaFwdApiPool.api.
FMHA_FWD_API = """
#include <cstdio>

namespace {{
bool get_num_cus(unsigned& num_cu) {{
    int device;
    auto status = hipGetDevice(&device);
    if(status != hipSuccess) {{
        fprintf(stderr, "failed to get device");
        return false;
    }}

    hipDeviceProp_t props{{}};
    status = hipGetDeviceProperties(&props, device);
    if(status != hipSuccess) {{
        fprintf(stderr, "failed to get device properties");
        return false;
    }}

    num_cu = props.multiProcessorCount;
    return true;
}}

unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{
    const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
    const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1

    return batch * nheads * num_m_blocks * num_n_blocks;
}}
}} // namespace

float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a, const ck_tile::stream_config& s) {{
    float r = -1;

    [[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate

    unsigned num_cus;
    if (!get_num_cus(num_cus)) {{
        return r;
    }}

    [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
        return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
    }};

{F_dispatch}
    return r;
}}
"""

# One dispatch branch per data type; {F_hdim_case} nests the per-hdim branches.
FMHA_FWD_API_PER_DTYPE = """    {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
    }}
"""
# One dispatch branch per head-dim bucket (hdim_q and hdim_v upper bounds).
FMHA_FWD_API_PER_HDIM_CASE = """        {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
{F_inner_dispatch}
        }}
"""

# Innermost branch: the full runtime-feature match (mode/layout/mask/bias/...)
# plus the per-trait padding checks and constraint, forwarding to one generated
# fmha_batch_prefill_<trait_> specialization.
FMHA_FWD_API_INNER_DISPATCH = """            {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse})  && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) &&
                        ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
                using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
                return fmha_batch_prefill_<trait_>(s, a);
            }}
"""


@dataclass
class CppConstraint:
    """An optional C++ boolean expression used as an extra dispatch guard.

    ``None`` means "no constraint" and renders as the C++ literal ``true`` so
    the generated condition can always be emitted unconditionally.
    Constraints compose with ``&`` into a parenthesized C++ ``&&`` expression.
    """

    # C++ expression text, or None for the always-true constraint.
    # (was annotated `str` with a None default — fixed to Optional[str])
    bool_expr: Optional[str] = None

    def __str__(self) -> str:
        """Render as a C++ expression fragment; None degenerates to "true"."""
        return "true" if self.bool_expr is None else f"{self.bool_expr}"

    def __and__(self, other: "CppConstraint") -> "CppConstraint":
        """Conjoin two constraints; both sides are parenthesized so C++
        operator precedence stays correct whatever the sub-expressions are."""
        return CppConstraint(f"({str(self)}) && ({str(other)})")


@dataclass
class FmhaFwdApiTrait:
    pipeline_tag: str
    # sync with fmha_fwd_traits<>, to generate fallback calls
    hdim: str
    dtype: str  # data type
    mode: str  # value from MODE_MAP
    bm0: int  # tile size along q seqlen (block size)
    bn0: int  # tile size along qk seqlen
    bk0: int  # tile size along qk gemm unroll
    bn1: int  # tile size along v head_dim
    bk1: int  # tile size along kv gemm unroll
    bk0max: int
    vlayout: str
    logits: str
    mask: str
    bias: str  #
    lse: str  #
    dropout: str
    qscale: str  #
    spad: str
    skpad: str
    dpad: str
    dvpad: str
    constraint: CppConstraint

    @property
    def name(self) -> str:
        return (
            f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
            + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}"
        )

    @property
    def scheck(self) -> str:
        if self.mode == "group":
            return "true/*group mode spad always true*/"  # group mode only generate spad/skpad == true
        if self.pipeline_tag == "qr_async":
            if self.spad == "t":
                return "true"  # always support
            else:
                return "true"
        elif self.pipeline_tag in ["qr"]:
            if self.spad == "t":
                return f"true /*a.seqlen_q % {self.bm0} != 0*/"  # TODO: order of get_pipelines() matters! (ugly)
            else:
                return f"a.seqlen_q % {self.bm0} == 0"
        else:
            assert False

    @property
    def skcheck(self) -> str:
        if self.mode == "group":
            return "true/*group mode skpad always true*/"  # group mode only generate spad/skpad == true
        if self.pipeline_tag == "qr_async":
            if self.skpad == "t":
                return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0"
            else:
                return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0"
        elif self.pipeline_tag in ["qr", "qr_fp8"]:
            if self.skpad == "t":
                return f"true /*a.seqlen_k % {self.bn0} != 0*/"  # TODO: order of get_pipelines() matters! (ugly)
            else:
                return f"a.seqlen_k % {self.bn0} == 0"
        else:
            assert False

    @property
    def dcheck(self) -> str:
        if self.pipeline_tag == "qr_async":
            vec = int((32 * 4) / DTYPE_BITS[self.dtype])
            if self.dpad == "t":
                return f"a.hdim_q % {vec} == 0"
            else:
                assert False
        elif self.pipeline_tag in ["qr"]:
            bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
            if self.dpad == "t":
                return f"true /*a.hdim_q % {bk0submax} != 0*/"  # TODO: order of get_pipelines() matters! (ugly)
            else:
                return f"a.hdim_q % {bk0submax} == 0"
        else:
            assert False

    @property
    def dvcheck(self) -> str:
        if self.pipeline_tag == "qr_async":
            vec = int((32 * 4) / DTYPE_BITS[self.dtype])
            if self.dvpad == "t":
                return f"a.hdim_v % {vec} == 0"
            else:
                assert False
        elif self.pipeline_tag in ["qr"]:
            bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
            if self.dvpad == "t":
                return f"true /*a.hdim_v % {bk0submax} != 0*/"  # TODO: order of get_pipelines() matters! (ugly)
            else:
                return f"a.hdim_v % {bk0submax} == 0"
        else:
            assert False


@dataclass
class FmhaFwdPipeline:
    tag: str

    F_vlayout: str  # row/col
    F_spad: str  # true/false
    F_skpad: str  #
    F_dpad: str  #
    F_dvpad: str  #
    F_logits: str  # t/f
    F_bias: str  # true/false
    F_lse: str  #
    F_dropout: str  #
    F_qscale: str  # no/pertensor
    F_mask: str  # value from MASK_MAP
    F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())

    @property
    def name(self) -> str:
        def pad_name() -> str:
            n = ""
            if self.F_spad == "t":
                n += "s"
            if self.F_skpad == "t":
                n += "sk"
            if self.F_dpad == "t":
                n += "d"
            if self.F_dvpad == "t":
                n += "dv"
            if n != "":
                n = "p" + n
            return n

        pn = pad_name()
        n = f"{self.tag}_v{self.F_vlayout[0]}"
        if pn != "":
            n += f"_{pn}"
        else:
            n += "_npad"

        if self.F_logits == "t":
            n += "_logits"
        else:
            n += "_nlogits"

        if self.F_bias != "no":
            n += f"_{self.F_bias}"
        else:
            n += "_nbias"

        if self.F_mask[0:2] == "s_":
            if self.F_mask == "s_mask":
                n += "_mask"
            else:
                n += "_nmask"
        else:
            if self.F_mask != "no":
                n += f"_m{self.F_mask[0]}"
            else:
                n += "_nmask"

        if self.F_lse == "t":
            n += "_lse"
        else:
            n += "_nlse"

        if self.F_dropout == "t":
            n += "_dropout"
        else:
            n += "_ndropout"

        if self.F_qscale != "no":
            n += f"_{self.F_qscale}"
        else:
            n += "_nqscale"
        return n


class FmhaFwdApiPool:
    """Accumulates the traits of all generated kernels and renders the
    fmha_batch_prefill() runtime-dispatch C++ source."""

    def __init__(self, mask_impl):
        # nested dict: dtype -> hdim -> [FmhaFwdApiTrait, ...] in registration order
        self.pool = dict()
        self.mask_impl = mask_impl

    def register_traits(self, trait: FmhaFwdApiTrait) -> None:
        """Record one kernel's trait; registration order fixes dispatch order."""
        # TODO: do we need to check duplication?
        if trait.dtype not in self.pool.keys():
            self.pool[trait.dtype] = dict()
        if trait.hdim not in self.pool[trait.dtype].keys():
            self.pool[trait.dtype][trait.hdim] = list()

        self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))

    @property
    def api(self) -> str:
        """Render the full dispatch file as a string.

        Builds three nested if/else ladders — data type, then head-dim
        bucket, then the per-trait runtime feature/padding check — each leaf
        forwarding to one generated fmha_batch_prefill_<trait_> specialization.
        """
        per_dtypes = str()
        for i, dtype in enumerate(self.pool.keys()):
            per_hdim_case = str()
            for j, hdim in enumerate(self.pool[dtype].keys()):
                traits = self.pool[dtype][hdim]
                inners = str()
                for k, trait in enumerate(traits):
                    # first branch is a plain "if", the rest are "else if"
                    if_k = "if" if k == 0 else "else if"
                    inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(
                        F_if=if_k,
                        F_mode=MODE_MAP[trait.mode],
                        F_vlayout=LAYOUT_MAP[trait.vlayout],
                        F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
                        F_logits=BOOL_MAP[trait.logits],
                        F_mask=get_mask_map(self.mask_impl)[trait.mask],
                        F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
                        F_bias_check=BIAS_CHECK_MAP[trait.bias],
                        F_bias=BIAS_MAP[trait.bias],
                        F_lse=BOOL_MAP[trait.lse],
                        F_dropout=BOOL_MAP[trait.dropout],
                        F_qscale_check=QSCALE_CHECK_MAP[trait.qscale],
                        F_qscale=QSCALE_MAP[trait.qscale],
                        F_scheck=trait.scheck,
                        F_skcheck=trait.skcheck,
                        F_dcheck=trait.dcheck,
                        F_dvcheck=trait.dvcheck,
                        F_constraint=trait.constraint,
                        F_spad=BOOL_MAP[trait.spad],
                        F_skpad=BOOL_MAP[trait.skpad],
                        F_dpad=BOOL_MAP[trait.dpad],
                        F_dvpad=BOOL_MAP[trait.dvpad],
                        F_bm0=trait.bm0,
                        F_bn0=trait.bn0,
                        F_bk0=trait.bk0,
                        F_bn1=trait.bn1,
                        F_bk1=trait.bk1,
                        F_bk0max=trait.bk0max,
                        F_hdim=hdim,
                        F_dtype=FWD_DTYPE_MAP[dtype],
                    )
                if_j = "if" if j == 0 else "else if"
                # NOTE(review): F_hdim_v reads `trait.bn1` from the *last* trait
                # of the loop above — this assumes every trait registered under
                # this hdim shares the same bn1, and would raise NameError if
                # `traits` were empty.  Confirm both invariants hold.
                per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
                    F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners
                )
            if_i = "if" if i == 0 else "else if"
            per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
                F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
            )
        if not per_dtypes:
            # empty string we add some ignore to suppress warning in api
            per_dtypes += "    (void)t; (void)s; (void)a;"
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_dtypes)


@dataclass
class FmhaFwdTileSize:
    F_bm0: int  # tile size along q seqlen (block size)
    F_bn0: int  # tile size along k seqlen
    F_bk0: int  # tile size along qk gemm unroll
    F_bn1: int  # tile size along v head_dim
    F_bk1: int  # tile size along kv gemm unroll
    F_bk0max: int  # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
    F_rm0: int  # number of warps for gemm0 along q seqlen
    F_rn0: int  # number of warps for gemm0 along k seqlen
    F_rk0: int  # number of warps for gemm0 along head dim q (not used)
    F_rm1: int  # number of warps for gemm1 along q seqlen
    F_rn1: int  # number of warps for gemm1 along head dim v
    F_rk1: int  # number of warps for gemm1 along k seqlen (not used)
    F_wm0: int  # gemm0 warp size along m
    F_wn0: int  # gemm0 warp size along n
    F_wk0: int  # gemm0 warp size along k
    F_wm1: int  # gemm1 warp size along m
    F_wn1: int  # gemm1 warp size along n
    F_wk1: int  # gemm1 warp size along k
    F_occupancy: int  # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
    F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())

    @property
    def name(self) -> str:
        return (
            f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}"
            + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}"
            + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}"
            + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
        )


@dataclass
class FmhaFwdKernel:
    """One concrete kernel instantiation: dtype + mode + tile + pipeline.

    Knows how to render its own .cpp source (``template``), its unique
    name / file name, and the FmhaFwdApiTrait used for runtime dispatch.
    """

    F_idx: int  # this is not a tunable, but a counter to differentiate symbol
    F_hdim: int  # hdim
    F_dtype: str  # data type
    F_mode: str  # value from MODE_MAP
    F_tile: FmhaFwdTileSize
    F_pipeline: FmhaFwdPipeline
    mask_impl: str

    @property
    def template(self) -> str:
        """Full C++ source for this kernel: common header plus the body
        template with every {F_*} placeholder mapped through the symbol maps."""
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format(
            F_idx=self.F_idx,
            F_hdim=self.F_hdim,
            F_dtype=FWD_DTYPE_MAP[self.F_dtype],
            F_bm0=self.F_tile.F_bm0,
            F_bn0=self.F_tile.F_bn0,
            F_bk0=self.F_tile.F_bk0,
            F_bn1=self.F_tile.F_bn1,
            F_bk1=self.F_tile.F_bk1,
            F_bk0max=self.F_tile.F_bk0max,
            F_rm0=self.F_tile.F_rm0,
            F_rn0=self.F_tile.F_rn0,
            F_rk0=self.F_tile.F_rk0,
            F_rm1=self.F_tile.F_rm1,
            F_rn1=self.F_tile.F_rn1,
            F_rk1=self.F_tile.F_rk1,
            F_wm0=self.F_tile.F_wm0,
            F_wn0=self.F_tile.F_wn0,
            F_wk0=self.F_tile.F_wk0,
            F_wm1=self.F_tile.F_wm1,
            F_wn1=self.F_tile.F_wn1,
            F_wk1=self.F_tile.F_wk1,
            F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
            F_spad=BOOL_MAP[self.F_pipeline.F_spad],
            F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
            F_dpad=BOOL_MAP[self.F_pipeline.F_dpad],
            F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
            F_logits=BOOL_MAP[self.F_pipeline.F_logits],
            F_bias=BIAS_MAP[self.F_pipeline.F_bias],
            F_lse=BOOL_MAP[self.F_pipeline.F_lse],
            F_dropout=BOOL_MAP[self.F_pipeline.F_dropout],
            F_qscale=QSCALE_MAP[self.F_pipeline.F_qscale],
            F_occupancy=self.F_tile.F_occupancy,
            F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
            F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
            F_mode=MODE_MAP[self.F_mode],
            F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag],
        )

    @property
    def name(self) -> str:
        """Unique kernel name; also used as the generated file's base name."""
        # TODO: we don't encode idx here
        return (
            f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
            + self.F_tile.name
            + "_"
            + self.F_pipeline.name
        )

    @property
    def filename(self) -> str:
        """Name of the generated .cpp file for this kernel."""
        return self.name + ".cpp"

    def api_trait(self) -> FmhaFwdApiTrait:
        """Build the dispatch trait for this kernel; the constraint is the
        conjunction of the tile's and the pipeline's constraints."""
        return FmhaFwdApiTrait(
            pipeline_tag=self.F_pipeline.tag,
            hdim=str(self.F_hdim),
            dtype=self.F_dtype,
            mode=self.F_mode,
            bm0=self.F_tile.F_bm0,
            bn0=self.F_tile.F_bn0,
            bk0=self.F_tile.F_bk0,
            bn1=self.F_tile.F_bn1,
            bk1=self.F_tile.F_bk1,
            bk0max=self.F_tile.F_bk0max,
            vlayout=self.F_pipeline.F_vlayout,
            mask=self.F_pipeline.F_mask,
            logits=self.F_pipeline.F_logits,
            bias=self.F_pipeline.F_bias,
            lse=self.F_pipeline.F_lse,
            dropout=self.F_pipeline.F_dropout,
            qscale=self.F_pipeline.F_qscale,
            spad=self.F_pipeline.F_spad,
            skpad=self.F_pipeline.F_skpad,
            dpad=self.F_pipeline.F_dpad,
            dvpad=self.F_pipeline.F_dvpad,
            constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
        )


class KernelComponentFactory:
    """Default provider of tile-size tables and pipeline permutations."""

    @staticmethod
    def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
        """Map hdim -> candidate tile sizes for *dtype*; None when unsupported."""
        if dtype not in ("fp16", "bf16"):
            return None
        return {
            128 : [FmhaFwdTileSize(128, 128, 32, 128, 32,  128,  4, 1, 1,  4, 1, 1,  32, 32, 16,  32, 32, 16,  -1)],
        }  # fmt: skip

    @staticmethod
    def get_pipelines(dtype, hdim, receipt, mask_impl) -> List[FmhaFwdPipeline]:
        """Enumerate every pipeline variant to generate for *dtype*.

        The list order matters: the dispatcher checks entries in order, so the
        skpad="f" (unpadded) variant must precede skpad="t" for each feature
        combination.  Only fp16/bf16 are supported.
        TODO: how to design this more generic?
        """
        assert dtype in ("fp16", "bf16")

        qscale = "no"
        pipelines: List[FmhaFwdPipeline] = []
        feature_space = itertools.product(
            ["t", "f"],                         # logits soft cap
            get_mask_map(mask_impl).keys(),     # mask kinds
            BIAS_MAP.keys(),                    # bias kinds
            ["t", "f"],                         # lse
            ["t", "f"],                         # dropout
        )
        for logits, mask, bias, lse, dropout in feature_space:
            # skpad "f" first, then "t" — the dispatch order relies on this;
            # "col" vlayout variants are intentionally disabled for now
            for skpad in ("f", "t"):
                pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", skpad, "t", "t", logits, bias, lse, dropout, qscale, mask))  # fmt: skip
        return pipelines


class CustomFactory(KernelComponentFactory):
    """Factory that prepends a smaller-bm0 tile guarded by a CU-utilization
    constraint to the default tile table."""

    @staticmethod
    def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
        """Extend the base hdim -> tiles table for fp16/bf16 at hdim 128."""
        table = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
        if dtype in ("fp16", "bf16") and 128 in table:
            # Smaller bm0 tile, tried first, but only when the default tile
            # would under-utilize the CUs (get_num_blocks/num_cus live in the
            # generated dispatch function).
            small_tile = FmhaFwdTileSize( 64, 128, 64, 128, 64,  128,  4, 1, 1,  4, 1, 1,  16, 16, 16,  16, 16, 16,  -1, CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"))  # fmt: skip
            table[128].insert(0, small_tile)
        return table


def get_fwd_blobs(
    kernel_filter: Optional[str], receipt, optdim_list, mask_impl
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
    """Enumerate every kernel instantiation to generate.

    Iterates dtype x (hdim, tiles) x mode x pipeline, drops invalid or
    excluded combinations, and returns the dispatch pool plus the surviving
    kernels.

    :param kernel_filter: fnmatch-style pattern on kernel names; None or ""
                          disables filtering (BUGFIX: None previously crashed
                          fnmatch despite the Optional[str] annotation).
    :param receipt: integration-specific pruning recipe (see inline cases).
    :param optdim_list: hdim whitelist; [-1] disables the restriction.
    :param mask_impl: mask implementation selector forwarded to the mask maps.
    """
    # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
    #       support this in future

    gen = list()
    api_pool = FmhaFwdApiPool(mask_impl)

    for dtype in FWD_DTYPE_MAP.keys():
        d = CustomFactory.get_hdim_tile_size_dict(dtype)
        if d is None:
            continue
        for (hdim, tiles), mode in itertools.product(d.items(), MODE_MAP.keys()):
            for tile, pipeline in itertools.product(
                tiles, CustomFactory.get_pipelines(dtype, hdim, receipt, mask_impl)
            ):
                if mode == "group":
                    if pipeline.F_spad != "t" or pipeline.F_skpad != "t":
                        # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
                        continue
                if hdim == 192 and tile.F_bn1 == 128:
                    # NOTE: this is used to speedup deepseek prefill case, we don't gen training
                    if (
                        pipeline.F_bias != "no"
                        or pipeline.F_lse == "t"
                        or pipeline.F_dropout == "t"
                    ):
                        continue
                # logits_soft_cap is only allowed if no bias
                if not (
                    (pipeline.F_logits == "t" and pipeline.F_bias == "no")
                    or pipeline.F_logits == "f"
                ):
                    continue
                k = FmhaFwdKernel(
                    F_idx=0,
                    F_hdim=hdim,
                    F_dtype=dtype,
                    F_mode=mode,
                    F_tile=tile,
                    F_pipeline=pipeline,
                    mask_impl=mask_impl,
                )
                # treat None the same as "" (no filter); fnmatch rejects None
                if kernel_filter:
                    if not fnmatch.fnmatch(k.name, kernel_filter):
                        continue
                if optdim_list != [-1]:
                    if hdim not in optdim_list:
                        continue
                # 2 - Flash attention integration
                if receipt in (2, 3):
                    cond = dtype in ["fp16", "bf16"]
                    cond &= pipeline.F_vlayout == "row"
                    cond &= pipeline.F_bias in ["no", "alibi"]
                    cond &= pipeline.F_qscale == "no"
                    if not cond:
                        continue
                # PyTorch integration
                elif receipt == 4:
                    cond = dtype in ["fp16", "bf16"]
                    cond &= pipeline.F_vlayout == "row"
                    cond &= pipeline.F_bias in ["no", "bias"]
                    cond &= pipeline.F_qscale == "no"
                    if not cond:
                        continue
                # Aiter(mha_fwd) integration
                elif receipt == 100:
                    cond = dtype in ["fp16", "bf16"]
                    cond &= mode == "batch"
                    cond &= pipeline.F_vlayout == "row"
                    cond &= pipeline.F_qscale == "no"
                    if not cond:
                        continue
                # Aiter(mha_batch_prefill) integration
                elif receipt == 200:
                    cond = dtype in ["fp16", "bf16"]
                    cond &= mode == "group"
                    cond &= pipeline.F_vlayout == "row"
                    cond &= pipeline.F_qscale == "no"
                    if not cond:
                        continue
                # aiter::mha_batch_prefill C++ api integration
                elif receipt == 600:
                    cond = dtype in ["fp16", "bf16"]
                    cond &= mode == "group"
                    cond &= pipeline.F_vlayout == "row"
                    cond &= pipeline.F_qscale == "no"
                    if not cond:
                        continue

                # fp32 only (note: a plain `if`, applied on top of any
                # receipt branch above)
                if receipt == 800 or receipt == 801:
                    cond = dtype == "fp32"
                    if not cond:
                        continue

                api_pool.register_traits(k.api_trait())
                gen.append(k)

    return (api_pool, gen)


def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
    """Write the generated C++ source of *kernel* into *autogen_dir* via update_file."""
    destination = autogen_dir / kernel.filename
    update_file(destination, kernel.template)


def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
    """Write the rendered runtime-dispatch .cpp of *api_pool* into *autogen_dir*."""
    destination = autogen_dir / FMHA_FWD_API_FILENAME
    update_file(destination, api_pool.api)


def write_blobs(
    targets: List[str],
    output_dir: Path,
    kernel_filter: str,
    receipt,
    optdim_list,
    mask_impl,
) -> None:
    """Generate every selected kernel .cpp plus the dispatch API file into *output_dir*."""
    api_pool, selected = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
    for k in selected:
        write_single_fwd_kernel(k, output_dir)
    write_fwd_api(api_pool, output_dir)


def list_blobs(
    targets: List[str],
    file_path: Path,
    kernel_filter: str,
    receipt,
    optdim_list,
    mask_impl,
) -> None:
    """Append the path of every file that would be generated to *file_path*,
    one per line (kernels first, then the dispatch API file)."""
    gen_root = file_path.parent / GEN_DIR
    _, selected = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
    lines = [str(gen_root / k.filename) + "\n" for k in selected]
    lines.append(str(gen_root / FMHA_FWD_API_FILENAME) + "\n")
    with file_path.open("a") as f:
        f.writelines(lines)
